{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "597OjogAI3fy"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "5bSCD8SyJC2g"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "E_ceEiH7g0MY"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/keras/custom_callback\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />Смотрите на TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ru/guide/keras/custom_callback.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Запустите в Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ru/guide/keras/custom_callback.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />Изучайте код на GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ru/guide/keras/custom_callback.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Скачайте ноутбук</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fj66ZXAzrJC2"
      },
      "source": [
        "Note: Вся информация в этом разделе переведена с помощью русскоговорящего Tensorflow сообщества на общественных началах. Поскольку этот перевод не является официальным, мы не гарантируем что он на 100% аккуратен и соответствует [официальной документации на английском языке](https://www.tensorflow.org/?hl=en). Если у вас есть предложение как исправить этот перевод, мы будем очень рады увидеть pull request в [tensorflow/docs](https://github.com/tensorflow/docs) репозиторий GitHub. Если вы хотите помочь сделать документацию по Tensorflow лучше (сделать сам перевод или проверить перевод подготовленный кем-то другим), напишите нам на [docs-ru@tensorflow.org list](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-ru)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1OykC-6lI4gv"
      },
      "source": [
        "# Кастомные колбеки в Keras\n",
        "Кастомный колбек это мощный инструмент для настройки поведения модели Keras во время обучения, оценки или вывода, включая чтение/изменение модели Keras. Примеры включают `tf.keras.callbacks.TensorBoard` где процесс обучения и результаты могут быть экспортированы и визуализированы в TensorBoard, или `tf.keras.callbacks.ModelCheckpoint` где модель автоматически сохраняется во время обучения, и т.д. В этом руководстве вы узнаете, что такое колбек Keras, когда он будет вызван, что он может делать, и как вы можете построить свой колбек. Ближе к концу руководства будет несколько демонстраций создания пары простых колбек-приложений, чтобы помочь вам начать делать собственый колбек."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d5zZ8rZD69VW"
      },
      "source": [
        "## Установка"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7BazS4qD6-2n"
      },
      "outputs": [],
      "source": [
        "from __future__ import absolute_import, division, print_function, unicode_literals\n",
        "\n",
        "try:\n",
        "  # %tensorflow_version существует только в Colab.\n",
        "  %tensorflow_version 2.x\n",
        "except Exception:\n",
        "  pass\n",
        "import tensorflow as tf"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0c_TYhQOUe1j"
      },
      "source": [
        "## Введение в колбеки Keras\n",
        "В Keras, `Callback` это базовый класс python предназначенный для наследования и обеспечивающий определенную функциональность. Класс `Callback` имеет предопределнный набор методов, вызываемых на различных этапах обучения (включая начало и конец пакета/эпохи), тестирования и прогнозирования. Колбеки используются для того, чтобы получить представлениее о внутренних состояниях и статистике модели во время обучения. Вы можете передать список колбеков (в качестве ключевого слова аргумента `callbacks`) любому из методов `tf.keras.Model.fit()`, `tf.keras.Model.evaluate()` и `tf.keras.Model.predict()`. Методы колбеков будут вызываться на разных этапах обучения/оценки/вывода.\n",
        "\n",
        "Для начала давайте импортируем tensorflow и определим простую Sequential модель Keras:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ct0VCSI2dt3a"
      },
      "outputs": [],
      "source": [
        "# Определим модель Keras чтобы добавить в нее колбеки\n",
        "def get_model():\n",
        "  model = tf.keras.Sequential()\n",
        "  model.add(tf.keras.layers.Dense(1, activation = 'linear', input_dim = 784))\n",
        "  model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=0.1), loss='mean_squared_error', metrics=['mae'])\n",
        "  return model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ySzdG1IqNgah"
      },
      "source": [
        "Затем загрузим данные MNIST из Keras datasets API для обучения и тестирования:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fwo9LMKGNPWr"
      },
      "outputs": [],
      "source": [
        "# Загрузим данные MNIST и предобработаем их\n",
        "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
        "x_train = x_train.reshape(60000, 784).astype('float32') / 255\n",
        "x_test = x_test.reshape(10000, 784).astype('float32') / 255"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kHVK7kceNqH2"
      },
      "source": [
        "Теперь определим простой колбек который будет просто отслеживать начало и конец каждого пакета данных. Во время этих вызовов он будет печатать индекс текущего пакета."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-dfuGTMINKRR"
      },
      "outputs": [],
      "source": [
        "import datetime\n",
        "\n",
        "class MyCustomCallback(tf.keras.callbacks.Callback):\n",
        "\n",
        "  def on_train_batch_begin(self, batch, logs=None):\n",
        "    print('Training: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))\n",
        "\n",
        "  def on_train_batch_end(self, batch, logs=None):\n",
        "    print('Training: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))\n",
        "\n",
        "  def on_test_batch_begin(self, batch, logs=None):\n",
        "    print('Evaluating: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))\n",
        "\n",
        "  def on_test_batch_end(self, batch, logs=None):\n",
        "    print('Evaluating: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z4FTUUIBN3WG"
      },
      "source": [
        "Передача колбека в метод модели, например `tf.keras.Model.fit()` гарантирует, что колбек будет вызываться на тех этапах, которые мы в нем определили. Например в примере выше это будет начало/конец тренировочного и тестового пакета(batch)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NJV6Tj3sNGzg"
      },
      "outputs": [],
      "source": [
        "model = get_model()\n",
        "_ = model.fit(x_train, y_train,\n",
        "          batch_size=64,\n",
        "          epochs=1,\n",
        "          steps_per_epoch=5,\n",
        "          verbose=0,\n",
        "          callbacks=[MyCustomCallback()])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fIy5JKMlZNmh"
      },
      "source": [
        "## Методы Model работающие с колбеками\n",
        "Пользователи могут передавать список колбеков в следующие методы `tf.keras.Model`:\n",
        "#### [`fit()`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit), [`fit_generator()`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit_generator)\n",
        "Обучает модель за фиксированное количество эпох (итерации по датасету, или данные полученные по-пакетно с помощью генератора Python).\n",
        "#### [`evaluate()`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#evaluate), [`evaluate_generator()`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#evaluate_generator)\n",
        "Оценивает модель для имеющихся данных или генератора данных. Выводит значения потерь и метрик во время оценки.\n",
        "#### [`predict()`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#predict), [`predict_generator()`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#predict_generator)\n",
        "Генерирует предсказания для входных данных или генератора данных.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "J00bXBbqdnJe"
      },
      "outputs": [],
      "source": [
        "_ = model.evaluate(x_test, y_test, batch_size=128, verbose=0, steps=5,\n",
        "          callbacks=[MyCustomCallback()])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "13n44LVkYQsV"
      },
      "source": [
        "## Обзор методов колбеков\n",
        "\n",
        "\n",
        "### Общие методы для обучения/тестирования/предсказания\n",
        "Для обучения, тестирования и предсказания предоставляются следующие методы для переопределения.\n",
        "#### `on_(train|test|predict)_begin(self, logs=None)`\n",
        "Вызывается в начале `fit`/`evaluate`/`predict`.\n",
        "#### `on_(train|test|predict)_end(self, logs=None)`\n",
        "Вызывается в конце `fit`/`evaluate`/`predict`.\n",
        "\n",
        "### Batch-level методы для обучения/тестирования/предсказания\n",
        "#### `on_(train|test|predict)_batch_begin(self, batch, logs=None)`\n",
        "Вызывается непосредственно перед обработкой пакета во время обучения/тестирования/предсказания. С этим методом, `logs` это словарь с ключами `batch` и `size`, представляющие номер текущего пакета и размер пакета.\n",
        "#### `on_(train|test|predict)_batch_end(self, batch, logs=None)`\n",
        "Вызывается в конце обучения/тестирования/предсказания на пакете. В этом методе `logs` это словарь содержащий результаты метрик.\n",
        "\n",
        "### Epoch-level  методы(используются только в процессе обучения)\n",
        "#### on_epoch_begin(self, epoch, logs=None)\n",
        "Вызывается в начале эпохи во время обучения.\n",
        "#### on_epoch_end(self, epoch, logs=None)\n",
        "Вызывается в конце эпохи во время обучения.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SWf3mXYoceCz"
      },
      "source": [
        "### Использование словаря `logs`\n",
        "Словарь `logs` содержит величину потерь(loss) и все метрики в конце пакета или эпохи. Пример ниже включает величину потерь(loss) и среднеквадратичную ошибку."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "u4wIdcF9BjJH"
      },
      "outputs": [],
      "source": [
        "class LossAndErrorPrintingCallback(tf.keras.callbacks.Callback):\n",
        "\n",
        "  def on_train_batch_end(self, batch, logs=None):\n",
        "    print('For batch {}, loss is {:7.2f}.'.format(batch, logs['loss']))\n",
        "\n",
        "  def on_test_batch_end(self, batch, logs=None):\n",
        "    print('For batch {}, loss is {:7.2f}.'.format(batch, logs['loss']))\n",
        "\n",
        "  def on_epoch_end(self, epoch, logs=None):\n",
        "    print('Средние потери за эпоху {} равны {:7.2f}, а среднеквадратичная ошибка равна {:7.2f}.'.format(epoch, logs['loss'], logs['mae']))\n",
        "\n",
        "model = get_model()\n",
        "_ = model.fit(x_train, y_train,\n",
        "          batch_size=64,\n",
        "          steps_per_epoch=5,\n",
        "          epochs=3,\n",
        "          verbose=0,\n",
        "          callbacks=[LossAndErrorPrintingCallback()])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LbXqvC8FHqeu"
      },
      "source": [
        "Аналогично можно использовать колбеки в вызове `evaluate()`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jzTKYPQHwcxF"
      },
      "outputs": [],
      "source": [
        "_ = model.evaluate(x_test, y_test, batch_size=128, verbose=0, steps=20,\n",
        "          callbacks=[LossAndErrorPrintingCallback()])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HnSljqtsXKfb"
      },
      "source": [
        "## Примеры колбек-модулей Keras\n",
        "Следующий раздел поможет вам в создании простых  Callback модулей."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kptNF0--Lznv"
      },
      "source": [
        "### Ранняя остановка при минимальном значении функции потерь\n",
        "Первый пример демонстрирует создание `Callback`а который останавливает обучение Keras когда достигнута минимумальная величина потерь, путем изменения аргумента `model.stop_training` (булево значение). Опционально пользователь может использовать аргумент `patience`, чтобы указать сколько дополнительных эпох обучаться до полной остановки при получении минимального значения потерь.\n",
        "\n",
        "`tf.keras.callbacks.EarlyStopping` обеспечиваает более полную и общую реализацию."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BM31gfAV4mks"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "class EarlyStoppingAtMinLoss(tf.keras.callbacks.Callback):\n",
        "  \"\"\"Остановить обучение, когда loss на минимуме, т.е. loss прекратил уменьшаться.\n",
        "\n",
        "  Аргументы:\n",
        "      patience: Количество эпох ожидания после достижения минимума. Если столько\n",
        "      эпох нет улучшения, обучение останавливается.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, patience=0):\n",
        "    super(EarlyStoppingAtMinLoss, self).__init__()\n",
        "\n",
        "    self.patience = patience\n",
        "\n",
        "    # best_weights для хранения весов на которых достигнут минимум потерь.\n",
        "    self.best_weights = None\n",
        "\n",
        "  def on_train_begin(self, logs=None):\n",
        "    # Количество эпох ожидания после достижения минимума.\n",
        "    self.wait = 0\n",
        "    # Эпоха на которой остановилось обучение.\n",
        "    self.stopped_epoch = 0\n",
        "    # Устанавливаем self.best равным бесконечности.\n",
        "    self.best = np.Inf\n",
        "\n",
        "  def on_epoch_end(self, epoch, logs=None):\n",
        "    current = logs.get('loss')\n",
        "    if np.less(current, self.best):\n",
        "      self.best = current\n",
        "      self.wait = 0\n",
        "      # Записать лучшие веса если текущая величина потерь меньше значения в self.best.\n",
        "      self.best_weights = self.model.get_weights()\n",
        "    else:\n",
        "      self.wait += 1\n",
        "      if self.wait >= self.patience:\n",
        "        self.stopped_epoch = epoch\n",
        "        self.model.stop_training = True\n",
        "        print('Восстановление весов модели с конца лучшей эпохи.')\n",
        "        self.model.set_weights(self.best_weights)\n",
        "\n",
        "  def on_train_end(self, logs=None):\n",
        "    if self.stopped_epoch > 0:\n",
        "      print('Эпоха %05d: ранняя остановка' % (self.stopped_epoch + 1))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xS4fa-7PFzzc"
      },
      "outputs": [],
      "source": [
        "model = get_model()\n",
        "_ = model.fit(x_train, y_train,\n",
        "          batch_size=64,\n",
        "          steps_per_epoch=5,\n",
        "          epochs=30,\n",
        "          verbose=0,\n",
        "          callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SpVDjs_Dkkdh"
      },
      "source": [
        "### Планирование скорости обучения\n",
        "\n",
        "При обучении модели часто используют изменение скорости обучения по мере того, как проходит больше эпох.\n",
        "\n",
        "Примечание: это лишь реализация примера, для более глубокого понимания смотрите `callbacks.LearningRateScheduler` и `keras.optimizers.schedules`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PGowEUC8klSz"
      },
      "outputs": [],
      "source": [
        "class LearningRateScheduler(tf.keras.callbacks.Callback):\n",
        "  \"\"\"Планировщик скорости обучения, устанавливающий скорость в соответствии с расписанием.\n",
        "\n",
        "  Аргументы:\n",
        "      schedule: функция которая получает на вход индекс эпохи\n",
        "          (целое число, индексируемое с 0) и текущую скорость обучения\n",
        "          и возвращает на выходе новую скорость обучения(float).\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, schedule):\n",
        "    super(LearningRateScheduler, self).__init__()\n",
        "    self.schedule = schedule\n",
        "\n",
        "  def on_epoch_begin(self, epoch, logs=None):\n",
        "    if not hasattr(self.model.optimizer, 'lr'):\n",
        "      raise ValueError('Optimizer must have a \"lr\" attribute.')\n",
        "    # Получаем текущую скорость обучения из оптимизатора модели.\n",
        "    lr = float(tf.keras.backend.get_value(self.model.optimizer.lr))\n",
        "    # Вызываем функцию расписания, чтобы получить запланированную скорость обучения.\n",
        "    scheduled_lr = self.schedule(epoch, lr)\n",
        "    # Устанавливаем новое значение в оптимизатор до начала новой эпохи\n",
        "    tf.keras.backend.set_value(self.model.optimizer.lr, scheduled_lr)\n",
        "    print('\\nEpoch %05d: Learning rate is %6.4f.' % (epoch, scheduled_lr))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1jL3pI5Ep5F8"
      },
      "outputs": [],
      "source": [
        "LR_SCHEDULE = [\n",
        "    # (epoch to start, learning rate) tuples\n",
        "    (3, 0.05), (6, 0.01), (9, 0.005), (12, 0.001)\n",
        "]\n",
        "\n",
        "def lr_schedule(epoch, lr):\n",
        "  \"\"\"Вспомогательная функция для получения запланированной скорости обучения на основе порядкового номера эпохи.\"\"\"\n",
        "  if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:\n",
        "    return lr\n",
        "  for i in range(len(LR_SCHEDULE)):\n",
        "    if epoch == LR_SCHEDULE[i][0]:\n",
        "      return LR_SCHEDULE[i][1]\n",
        "  return lr\n",
        "\n",
        "model = get_model()\n",
        "_ = model.fit(x_train, y_train,\n",
        "          batch_size=64,\n",
        "          steps_per_epoch=5,\n",
        "          epochs=15,\n",
        "          verbose=0,\n",
        "          callbacks=[LossAndErrorPrintingCallback(), LearningRateScheduler(lr_schedule)])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9xMkm699JzK8"
      },
      "source": [
        "### Стандартные колбеки Keras\n",
        "Узнать больше о колбеках Keras вы можете [прочитав полную документацию по API](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks). Там представлены такие колбеки как логирование в CSV, сохранение модели, визуализацию на TensorBoard и многое другое."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "9xMkm699JzK8"
      ],
      "name": "custom_callback.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
